sanjanatule commited on
Commit
fd892c6
1 Parent(s): d7298ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -23,7 +23,7 @@ phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name,trust_remote_cod
23
  # load weights
24
  model_to_merge = PeftModel.from_pretrained(phi_model,'./model_chkpt/lora_adaptor')
25
  merged_model = model_to_merge.merge_and_unload()
26
- projection.load_state_dict(torch.load('./model_chkpt/step2_projection.pth'))
27
 
28
  def model_generate_ans(img,val_q):
29
 
@@ -48,7 +48,7 @@ def model_generate_ans(img,val_q):
48
  phi_output_logits = peft_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
49
  predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
50
  predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
51
- predicted_caption[:,g] = predicted_word_token.view(1,-1).to('cpu')
52
 
53
  predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)
54
 
 
23
  # load weights
24
  model_to_merge = PeftModel.from_pretrained(phi_model,'./model_chkpt/lora_adaptor')
25
  merged_model = model_to_merge.merge_and_unload()
26
+ projection.load_state_dict(torch.load('./model_chkpt/step2_projection.pth',map_location=torch.device(device)))
27
 
28
  def model_generate_ans(img,val_q):
29
 
 
48
  phi_output_logits = peft_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
49
  predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
50
  predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
51
+ predicted_caption[:,g] = predicted_word_token.view(1,-1).to(device)
52
 
53
  predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)
54