wenkai commited on
Commit
ea37187
·
verified ·
1 Parent(s): 6b16660

Update lavis/models/protein_models/protein_function_opt.py

Browse files
lavis/models/protein_models/protein_function_opt.py CHANGED
@@ -98,26 +98,15 @@ class Blip2ProteinMistral(Blip2ProteinBase):
98
 
99
  self.mistral_tokenizer = LlamaTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
100
  # self.mistral_tokenizer = LlamaTokenizer.from_pretrained("/cluster/home/wenkai/.cache/huggingface/hub/models--teknium--OpenHermes-2.5-Mistral-7B", use_fast=False)
101
- # configuration = MistralConfig()
102
  self.mistral_tokenizer.pad_token = '<pad>'
103
- self.mistral_model = MistralForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16)
104
- # self.mistral_model = MistralForCausalLM.from_pretrained("/cluster/home/wenkai/.cache/huggingface/hub/models--teknium--OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16)
105
- # self.mistral_model = MistralForCausalLM(configuration)
106
- for name, param in self.mistral_model.named_parameters():
107
- param.requires_grad = False
108
- #self.mistral_model.lm_head = self.mistral_model.lm_head.float()
109
- #for param in self.mistral_model.lm_head.parameters():
110
- # param.requires_grad = True
111
-
112
- #self.eos_token_id = self.mistral_tokenizer(
113
- # "\n", add_special_tokens=False
114
- #).input_ids[0]
115
  self.eos_token_id = self.mistral_tokenizer(
116
  "\n", add_special_tokens=False
117
  ).input_ids[1]
118
- print(f"LLM hidden size: {self.mistral_model.config.hidden_size}")
119
  self.opt_proj = nn.Linear(
120
- self.Qformer.config.hidden_size, self.mistral_model.config.hidden_size
121
  )
122
 
123
  self.max_txt_len = max_txt_len
@@ -191,7 +180,6 @@ class Blip2ProteinMistral(Blip2ProteinBase):
191
  )
192
  targets = torch.cat([empty_targets, targets], dim=1)
193
 
194
- #inputs_embeds = self.mistral_model.model.decoder.embed_tokens(mistral_tokens.input_ids)
195
  inputs_embeds = self.mistral_model.model.embed_tokens(mistral_tokens.input_ids)
196
  inputs_embeds = torch.cat([inputs_mistral, inputs_embeds], dim=1)
197
  attention_mask = torch.cat([atts_mistral, mistral_tokens.attention_mask], dim=1)
@@ -209,6 +197,7 @@ class Blip2ProteinMistral(Blip2ProteinBase):
209
  @torch.no_grad()
210
  def generate(
211
  self,
 
212
  samples,
213
  # use_nucleus_sampling=False,
214
  num_beams=15,
@@ -262,8 +251,8 @@ class Blip2ProteinMistral(Blip2ProteinBase):
262
  truncation=True,
263
  max_length=self.max_txt_len,
264
  ).to(self.device)
265
- # inputs_embeds = self.mistral_model.model.decoder.embed_tokens(mistral_tokens.input_ids)
266
- inputs_embeds = self.mistral_model.model.embed_tokens(mistral_tokens.input_ids)
267
  inputs_embeds = torch.cat([inputs_mistral, inputs_embeds], dim=1)
268
  attention_mask = torch.cat([atts_mistral, mistral_tokens.attention_mask], dim=1)
269
  # if name[0] == 'Pin':
@@ -275,7 +264,7 @@ class Blip2ProteinMistral(Blip2ProteinBase):
275
  #num_txt = 15
276
  #return_num_txt = 10
277
  with torch.no_grad():
278
- outputs = self.mistral_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=min_length,
279
  max_new_tokens=max_length, temperature=temperature, return_dict_in_generate=True,
280
  output_scores=True,
281
  repetition_penalty=repetition_penalty, num_beams=num_beams,
 
98
 
99
  self.mistral_tokenizer = LlamaTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
100
  # self.mistral_tokenizer = LlamaTokenizer.from_pretrained("/cluster/home/wenkai/.cache/huggingface/hub/models--teknium--OpenHermes-2.5-Mistral-7B", use_fast=False)
 
101
  self.mistral_tokenizer.pad_token = '<pad>'
102
+ # self.mistral_model = MistralForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16)
103
+ self.mistral_model = None
104
+
 
 
 
 
 
 
 
 
 
105
  self.eos_token_id = self.mistral_tokenizer(
106
  "\n", add_special_tokens=False
107
  ).input_ids[1]
 
108
  self.opt_proj = nn.Linear(
109
+ self.Qformer.config.hidden_size, 4096
110
  )
111
 
112
  self.max_txt_len = max_txt_len
 
180
  )
181
  targets = torch.cat([empty_targets, targets], dim=1)
182
 
 
183
  inputs_embeds = self.mistral_model.model.embed_tokens(mistral_tokens.input_ids)
184
  inputs_embeds = torch.cat([inputs_mistral, inputs_embeds], dim=1)
185
  attention_mask = torch.cat([atts_mistral, mistral_tokens.attention_mask], dim=1)
 
197
  @torch.no_grad()
198
  def generate(
199
  self,
200
+ mistral_model,
201
  samples,
202
  # use_nucleus_sampling=False,
203
  num_beams=15,
 
251
  truncation=True,
252
  max_length=self.max_txt_len,
253
  ).to(self.device)
254
+
255
+ inputs_embeds = mistral_model.model.embed_tokens(mistral_tokens.input_ids)
256
  inputs_embeds = torch.cat([inputs_mistral, inputs_embeds], dim=1)
257
  attention_mask = torch.cat([atts_mistral, mistral_tokens.attention_mask], dim=1)
258
  # if name[0] == 'Pin':
 
264
  #num_txt = 15
265
  #return_num_txt = 10
266
  with torch.no_grad():
267
+ outputs = mistral_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=min_length,
268
  max_new_tokens=max_length, temperature=temperature, return_dict_in_generate=True,
269
  output_scores=True,
270
  repetition_penalty=repetition_penalty, num_beams=num_beams,