wenkai committed on
Commit 5b42959
1 Parent(s): 0f66ac3

Update lavis/models/protein_models/protein_function_opt.py

lavis/models/protein_models/protein_function_opt.py CHANGED
@@ -1,455 +1,458 @@
"""
Copyright (c) 2023, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
from packaging import version

import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
import torch.nn.functional as F

from lavis.common.registry import registry
from lavis.models.blip2_models.blip2 import Blip2Base, Blip2ProteinBase, disabled_train
from transformers import AutoTokenizer, LlamaTokenizer, MistralForCausalLM, MistralConfig
import transformers
import esm
import random
from lavis.models.base_model import FAPMConfig
+from esm import pretrained


def comb(s):
    s_list = [i.strip() for i in s.split(';')]
    random.shuffle(s_list)
    return '; '.join(s_list)


def process_text(txts, probs):
    res = dict()
    for txt, prob in zip(txts, probs):
        txt_sep = [x.strip() for x in txt.split(';')]
        for txt_sub in txt_sep:
            txt_sub = txt_sub.replace('|', '')
            if txt_sub not in res and txt_sub != '':
                res[txt_sub] = round(prob.item(), 3)
    return '; '.join([str((k, v)) for k, v in res.items()])



@registry.register_model("blip2_protein_mistral")
class Blip2ProteinMistral(Blip2ProteinBase):

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain_protein_mistral7b": "configs/models/blip2/pretrain_protein_mistral7b.yaml",
    }
    config_class = FAPMConfig

    def __init__(
        self,
        config,
        num_query_token=32,
        prompt="",
        max_txt_len=128,
        max_protein_len=128,
        apply_lemmatizer=False,
        get_eval=False,
        esm_size='650m'
    ):
        """
        apply_lemmatizer: when set to True, postprocess predict_answers() result with lemmas.
        """
        super().__init__(config)
        transformers_version = version.parse(transformers.__version__)
        assert transformers_version >= version.parse("4.27"), "BLIP-2 mistral requires transformers>=4.27"

        self.tokenizer = self.init_tokenizer()
        '''
        self.ln_vision, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        if freeze_vit:
            self.ln_vision = self.ln_vision.half()
        self.visual_encoder = alphabet.get_batch_converter(truncation_seq_length=max_protein_len)
        self.padding_idx = alphabet.padding_idx
        self.vis_layers = self.ln_vision.num_layers

        if freeze_vit:
            for name, param in self.ln_vision.named_parameters():
                param.requires_grad = False
            self.ln_vision = self.ln_vision.eval()
            self.ln_vision.train = disabled_train
            logging.info("freeze vision encoder")
        else:
            for name, param in self.ln_vision.named_parameters():
                if 'contact_head' in name or 'emb_layer_norm_after' in name or 'lm_head' in name:
                    param.requires_grad = False
        '''
+        self.model_esm, self.alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
+
        if esm_size == '650m':
            self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, 1280)
        elif esm_size == '3b':
            self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, 2560)
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        self.mistral_tokenizer = LlamaTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
        # self.mistral_tokenizer = LlamaTokenizer.from_pretrained("/cluster/home/wenkai/.cache/huggingface/hub/models--teknium--OpenHermes-2.5-Mistral-7B", use_fast=False)
        # configuration = MistralConfig()
        self.mistral_tokenizer.pad_token = '<pad>'
        self.mistral_model = MistralForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
        # self.mistral_model = MistralForCausalLM.from_pretrained("/cluster/home/wenkai/.cache/huggingface/hub/models--teknium--OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16)
        # self.mistral_model = MistralForCausalLM(configuration)
        for name, param in self.mistral_model.named_parameters():
            param.requires_grad = False
        # self.mistral_model.lm_head = self.mistral_model.lm_head.float()
        # for param in self.mistral_model.lm_head.parameters():
        #     param.requires_grad = True

        # self.eos_token_id = self.mistral_tokenizer(
        #     "\n", add_special_tokens=False
        # ).input_ids[0]
        self.eos_token_id = self.mistral_tokenizer(
            "\n", add_special_tokens=False
        ).input_ids[1]
        print(f"LLM hidden size: {self.mistral_model.config.hidden_size}")
        self.opt_proj = nn.Linear(
            self.Qformer.config.hidden_size, self.mistral_model.config.hidden_size
        )

        self.max_txt_len = max_txt_len
        self.prompt = prompt
        prompt_tokens = self.mistral_tokenizer(self.prompt, return_tensors="pt")
        self.prompt_length = prompt_tokens.attention_mask.sum(1)

        self._apply_lemmatizer = apply_lemmatizer
        self._lemmatizer = None

    def forward(self, samples):
        '''
        image = samples["image"]
        image = [('protein{}'.format(i), x) for i, x in enumerate(image)]

        with self.maybe_autocast():
            _, _, batch_tokens = self.visual_encoder(image)
            image_embeds = self.ln_vision(batch_tokens.to(self.device), repr_layers=[self.vis_layers], return_contacts=True)["representations"][self.vis_layers].contiguous()
        '''
        image_embeds = samples["image"]
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            self.device
        )

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        inputs_mistral = self.opt_proj(query_output.last_hidden_state)

        # torch.save(query_output.last_hidden_state, '/cluster/home/wenkai/LAVIS/output/mf_bp_cc/query_output_mf/{}.pt'.format(samples['name'][0]))
        # torch.save(inputs_mistral, '/cluster/home/wenkai/LAVIS/output/mf_bp_cc/inputs_mistral_mf/{}.pt'.format(samples['name'][0]))

        atts_mistral = torch.ones(inputs_mistral.size()[:-1], dtype=torch.long).to(self.device)

        # prompt
        prompt = samples["prompt"]
        prompt_tokens = self.mistral_tokenizer(prompt, padding="longest", return_tensors="pt")
        prompt_length = prompt_tokens.attention_mask.sum(1)

        self.mistral_tokenizer.padding_side = "right"

        text = [p+' '+comb(t) + "\n" for p, t in zip(prompt, samples["text_input"])]
        text = [p+' '+ t + "\n" for p, t in zip(prompt, samples["text_input"])]

        mistral_tokens = self.mistral_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
        ).to(self.device)

        targets = mistral_tokens.input_ids.masked_fill(
            mistral_tokens.input_ids == self.mistral_tokenizer.pad_token_id, -100
        )

        for i, pl in enumerate(prompt_length):
            targets[i, :pl] = -100  # do not apply loss to the prompt
        # print(prompt_tokens, '\n', mistral_tokens, '\n', prompt_length)

        # if self.prompt:
        #     targets[:, : self.prompt_length] = -100  # do not apply loss to the prompt

        empty_targets = (
            torch.ones(atts_mistral.size(), dtype=torch.long).to(self.device).fill_(-100)
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        # inputs_embeds = self.mistral_model.model.decoder.embed_tokens(mistral_tokens.input_ids)
        inputs_embeds = self.mistral_model.model.embed_tokens(mistral_tokens.input_ids)
        inputs_embeds = torch.cat([inputs_mistral, inputs_embeds], dim=1)
        attention_mask = torch.cat([atts_mistral, mistral_tokens.attention_mask], dim=1)

        with self.maybe_autocast():
            outputs = self.mistral_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
            )
        loss = outputs.loss
        return {"loss": loss}

    @torch.no_grad()
    def generate(
        self,
        samples,
        # use_nucleus_sampling=False,
        num_beams=15,
        max_length=32,
        min_length=1,
        # top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=0.,
        num_captions=10,
        temperature=1,
    ):
        """
        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
            use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling.
            num_beams (int): Number of beams for beam search. 1 means no beam search.
            max_length (int): The maximum length of the sequence to be generated.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
            num_captions (int): Number of captions to be generated for each image.
        Returns:
            captions (list): A list of strings of length batch_size * num_captions.
        """
        with self.maybe_autocast():
            image_embeds = samples["image"]
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
                self.device
            )

            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

            inputs_mistral = self.opt_proj(query_output.last_hidden_state)
            atts_mistral = torch.ones(inputs_mistral.size()[:-1], dtype=torch.long).to(self.device)

            label = samples["text_input"]
            name = samples['name']
            text = samples['prompt']
            # text = ['' for i in range(len(label))]
            mistral_tokens = self.mistral_tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                truncation=True,
                max_length=self.max_txt_len,
            ).to(self.device)
            # inputs_embeds = self.mistral_model.model.decoder.embed_tokens(mistral_tokens.input_ids)
            inputs_embeds = self.mistral_model.model.embed_tokens(mistral_tokens.input_ids)
            inputs_embeds = torch.cat([inputs_mistral, inputs_embeds], dim=1)
            attention_mask = torch.cat([atts_mistral, mistral_tokens.attention_mask], dim=1)
            # if name[0] == 'Pin':
            #     torch.save(inputs_embeds, '/cluster/home/wenkai/LAVIS/output/inputs_embeds.pt')
            #     torch.save(attention_mask, '/cluster/home/wenkai/LAVIS/output/attention_mask.pt')

            # self.get_eval = False
            # '''
            # num_txt = 15
            # return_num_txt = 10
            with torch.no_grad():
                outputs = self.mistral_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=min_length,
                                                      max_new_tokens=max_length, temperature=temperature, return_dict_in_generate=True,
                                                      output_scores=True,
                                                      repetition_penalty=repetition_penalty, num_beams=num_beams,
                                                      length_penalty=length_penalty, num_return_sequences=num_captions,
                                                      eos_token_id=self.eos_token_id)
            output_text = self.mistral_tokenizer.batch_decode(outputs['sequences'], skip_special_tokens=True)
            '''
            num_txt = 5
            return_num_txt = 1
            with torch.no_grad():
                outputs = self.mistral_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=1,
                                                      max_length=96, temperature=1., return_dict_in_generate=True, output_scores=True,
                                                      repetition_penalty=1., num_beams=num_txt,
                                                      length_penalty=1, num_return_sequences=return_num_txt, eos_token_id=self.eos_token_id)
            output_text = self.mistral_tokenizer.batch_decode(outputs['sequences'], skip_special_tokens=True)
            '''
            probs = F.softmax(outputs['sequences_scores'])
            # print(output_text)
            output_text = [x.replace('\n', '').strip() for x in output_text]

            output_text_ = []
            for i in range(len(label)):
                # output_text_.append(';'.join(output_text[i*return_num_txt:(i+1)*return_num_txt]))
                output_text_.append(process_text(output_text[i * num_captions:(i + 1) * num_captions],
                                                 probs[i * num_captions:(i + 1) * num_captions]))
            # output_text_ = ['; '.join(list(set([i.strip() for i in x.split(';')]))) for x in output_text_]
            # with open('/cluster/home/wenkai/LAVIS/output/mf_bp_cc/output_test_mf_exp_493552.txt', 'a+', encoding="utf-8") as f:
            #     for i in range(len(label)):
            #         f.write(name[i] + "|" + output_text_[i] + "|" + label[i] + '\n')
        return output_text_


    def predict_answers(
        self,
        samples,
        num_beams=5,
        inference_method="generate",
        max_len=10,
        min_len=1,
        num_ans_candidates=128,
        answer_list=None,
        prompt="",
        length_penalty=0,
        **kwargs
    ):
        image_embeds = samples["image"]
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            self.device
        )

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        inputs_mistral = self.opt_proj(query_output.last_hidden_state)
        atts_mistral = torch.ones(inputs_mistral.size()[:-1], dtype=torch.long).to(self.device)

        label = samples["text_input"]
        name = samples['name']
        text = samples['prompt']
        # text = ['' for i in range(len(label))]
        mistral_tokens = self.mistral_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
        ).to(self.device)
        # inputs_embeds = self.mistral_model.model.decoder.embed_tokens(mistral_tokens.input_ids)
        inputs_embeds = self.mistral_model.model.embed_tokens(mistral_tokens.input_ids)
        inputs_embeds = torch.cat([inputs_mistral, inputs_embeds], dim=1)
        attention_mask = torch.cat([atts_mistral, mistral_tokens.attention_mask], dim=1)
        # if name[0] == 'Pin':
        #     torch.save(inputs_embeds, '/cluster/home/wenkai/LAVIS/output/inputs_embeds.pt')
        #     torch.save(attention_mask, '/cluster/home/wenkai/LAVIS/output/attention_mask.pt')

        # self.get_eval = False
        # '''
        # num_txt = 15
        # return_num_txt = 10
        num_txt = 15
        return_num_txt = 10
        with torch.no_grad():
            outputs = self.mistral_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                                                  min_length=1,
                                                  max_length=32, temperature=1., return_dict_in_generate=True,
                                                  output_scores=True,
                                                  repetition_penalty=1., num_beams=num_txt,
                                                  length_penalty=0., num_return_sequences=return_num_txt,
                                                  eos_token_id=self.eos_token_id)
        output_text = self.mistral_tokenizer.batch_decode(outputs['sequences'], skip_special_tokens=True)
        '''
        num_txt = 5
        return_num_txt = 1
        with torch.no_grad():
            outputs = self.mistral_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=1,
                                                  max_length=96, temperature=1., return_dict_in_generate=True, output_scores=True,
                                                  repetition_penalty=1., num_beams=num_txt,
                                                  length_penalty=1, num_return_sequences=return_num_txt, eos_token_id=self.eos_token_id)
        output_text = self.mistral_tokenizer.batch_decode(outputs['sequences'], skip_special_tokens=True)
        '''
        probs = F.softmax(outputs['sequences_scores'])
        # print(output_text)
        output_text = [x.replace('\n', '').strip() for x in output_text]

        output_text_ = []
        for i in range(len(label)):
            # output_text_.append(';'.join(output_text[i*return_num_txt:(i+1)*return_num_txt]))
            output_text_.append(process_text(output_text[i * return_num_txt:(i + 1) * return_num_txt],
                                             probs[i * return_num_txt:(i + 1) * return_num_txt]))
        return output_text_

    def _lemmatize(self, answers):
        def apply(answer):
            doc = self.lemmatizer(answer)

            words = []
            for token in doc:
                if token.pos_ in ["NOUN", "VERB"]:
                    words.append(token.lemma_)
                else:
                    words.append(token.text)
            answer = " ".join(words)

            return answer

        return [apply(answer) for answer in answers]

    @property
    def lemmatizer(self):
        if self._lemmatizer is None:
            try:
                import spacy

                self._lemmatizer = spacy.load("en_core_web_sm")
            except ImportError:
                logging.error(
                    """
                    Please install spacy and en_core_web_sm model to apply lemmatization.
                    python -m spacy download en_core_web_sm
                    OR
                    import spacy.cli
                    spacy.cli.download("en_core_web_sm")
                    """
                )
                exit(1)

        return self._lemmatizer

    @classmethod
    def from_config(cls, cfg):
        num_query_token = cfg.get("num_query_token")

        get_eval = cfg.get("get_eval", False)
        esm_size = cfg.get("esm_size", '650m')
        prompt = cfg.get("prompt", "")
        max_txt_len = cfg.get("max_txt_len", 128)
        max_protein_len = cfg.get("max_protein_len", 128)

        apply_lemmatizer = cfg.get("apply_lemmatizer", False)

        model = cls(
            num_query_token=num_query_token,
            prompt=prompt,
            max_txt_len=max_txt_len,
            max_protein_len=max_protein_len,
            apply_lemmatizer=apply_lemmatizer,
            get_eval=get_eval,
            esm_size=esm_size,
        )
        model.load_checkpoint_from_config(cfg)

        return model
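For context only (not part of the commit): the two added lines load the ESM-2 3B encoder via the public fair-esm API, while the forward pass of this class still consumes precomputed protein embeddings through samples["image"]. The sketch below shows how such embeddings are typically produced with that API, mirroring the commented-out encoder block in the file; the example sequence and variable names are made up, and layer 36 / hidden size 2560 refer to esm2_t36_3B_UR50D, matching the init_Qformer(num_query_token, 2560) branch.

# Minimal sketch, assuming the standard fair-esm API; not taken from this repository.
import torch
from esm import pretrained

model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
model_esm.eval()
batch_converter = alphabet.get_batch_converter()

# Hypothetical input; real sequences would come from the dataset loader.
data = [("protein0", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")]
_, _, batch_tokens = batch_converter(data)

with torch.no_grad():
    # Layer 36 is the final layer of the 3B model; its 2560-dim output is what the
    # Q-Former initialized for esm_size='3b' expects as encoder_hidden_states.
    results = model_esm(batch_tokens, repr_layers=[36])
image_embeds = results["representations"][36]  # shape: (batch, seq_len, 2560)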