sanjanatule committed on
Commit 31a9142
1 Parent(s): e9d7857

Update app.py

Files changed (1):
  1. app.py +24 -22
app.py CHANGED
@@ -45,32 +45,34 @@ resblock.load_state_dict(torch.load('./model_chkpt/step2_resblock.pth',map_locat
 def model_generate_ans(img,val_q):
 
     max_generate_length = 30
 
-    # image
-    image_processed = processor(images=img, return_tensors="pt").to(device)
-    clip_val_outputs = clip_model(**image_processed).last_hidden_state[:,1:,:]
-    val_image_embeds = projection(clip_val_outputs)
-    val_image_embeds = resblock(val_image_embeds).to(torch.float16)
-
-    img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
-    img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
-
-    val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
-    val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
-
-    val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
-
-    predicted_caption = torch.full((1,max_generate_length),50256)
-
-    for g in range(max_generate_length):
-        phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
-        predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
-        predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
-        predicted_caption[:,g] = predicted_word_token.view(1,-1).to(device)
-        next_token_embeds = phi_model.model.embed_tokens(predicted_word_token) # 4,1,2560
-        val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
-
-    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)
+    with torch.no_grad():
+
+        # image
+        image_processed = processor(images=img, return_tensors="pt").to(device)
+        clip_val_outputs = clip_model(**image_processed).last_hidden_state[:,1:,:]
+        val_image_embeds = projection(clip_val_outputs)
+        val_image_embeds = resblock(val_image_embeds).to(torch.float16)
+
+        img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
+        img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
+
+        val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
+        val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
+
+        val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
+
+        predicted_caption = torch.full((1,max_generate_length),50256)
+
+        for g in range(max_generate_length):
+            phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
+            predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
+            predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
+            predicted_caption[:,g] = predicted_word_token.view(1,-1)
+            next_token_embeds = phi_model.model.embed_tokens(predicted_word_token) # 4,1,2560
+            val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
+
+        predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)
 
     return predicted_captions_decoded
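
For reference, the loop wrapped in torch.no_grad() above is plain greedy decoding over inputs_embeds: take the logits for the last position, pick the argmax token, and concatenate its embedding back onto the input. Below is a minimal, self-contained sketch of that same pattern; GPT-2, the prompt, and the 20-step cap are stand-in assumptions for illustration only, not the merged CLIP + Phi setup in app.py.

# Illustrative sketch only: greedy decoding over inputs_embeds, with GPT-2
# standing in for the merged model used in app.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer("A photo of", return_tensors="pt")["input_ids"]
# start from token embeddings, analogous to val_combined_embeds above
combined_embeds = model.get_input_embeddings()(input_ids)

generated = []
with torch.no_grad():  # inference only, so no gradient tracking is needed
    for _ in range(20):  # assumed cap, mirrors max_generate_length
        logits = model(inputs_embeds=combined_embeds).logits   # (1, seq, vocab)
        next_token = torch.argmax(logits[:, -1, :], dim=-1)    # greedy pick, shape (1,)
        generated.append(next_token.item())
        if next_token.item() == tokenizer.eos_token_id:        # 50256 for GPT-2
            break
        next_embeds = model.get_input_embeddings()(next_token).unsqueeze(1)  # (1, 1, hidden)
        combined_embeds = torch.cat([combined_embeds, next_embeds], dim=1)

print(tokenizer.decode(generated))

Compared with the loop in the diff, this sketch only differs in breaking early on the EOS token and decoding the collected ids directly instead of pre-filling a tensor with 50256.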