sanjanatule committed on
Commit
8625fcd
1 Parent(s): fecb6e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -35,17 +35,17 @@ def model_generate_ans(img,val_q):
35
  val_image_embeds = projection(clip_val_outputs).to(torch.float16)
36
 
37
  img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
38
- img_token_embeds = peft_model.model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
39
 
40
  val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
41
- val_q_embeds = peft_model.model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
42
 
43
  val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
44
 
45
  predicted_caption = torch.full((1,max_generate_length),50256)
46
 
47
  for g in range(max_generate_length):
48
- phi_output_logits = peft_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
49
  predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
50
  predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
51
  predicted_caption[:,g] = predicted_word_token.view(1,-1).to(device)
 
35
  val_image_embeds = projection(clip_val_outputs).to(torch.float16)
36
 
37
  img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
38
+ img_token_embeds = merged_model.model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
39
 
40
  val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
41
+ val_q_embeds = merged_model.model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
42
 
43
  val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
44
 
45
  predicted_caption = torch.full((1,max_generate_length),50256)
46
 
47
  for g in range(max_generate_length):
48
+ phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
49
  predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
50
  predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
51
  predicted_caption[:,g] = predicted_word_token.view(1,-1).to(device)