sanjanatule committed
Commit 429d535
1 Parent(s): ea37b8e

Update app.py

Files changed (1): app.py +15 -22
app.py CHANGED
@@ -82,32 +82,25 @@ def model_generate_ans(img=None,img_audio=None,val_q=None):
     val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
     val_combined_embeds.append(val_q_embeds)
 
-    # val_combined_embeds = []
-    # if img is not None:
-    #     #val_combined_embeds = torch.cat([val_combined_embeds, val_image_embeds, img_token_embeds], dim=1)
-    #     val_combined_embeds.append(val_image_embeds)
-    #     val_combined_embeds.append(img_token_embeds)
-    # if img_audio is not None:
-    #     #val_combined_embeds = torch.cat([val_combined_embeds, audio_embeds], dim=1)
-    #     val_combined_embeds.append(audio_embeds)
-    # if len(val_q) != 0:
-    #     #val_combined_embeds = torch.cat([val_combined_embeds, val_q_embeds], dim=1)
-    #     val_combined_embeds.append(val_q_embeds)
 
     val_combined_embeds = torch.cat(val_combined_embeds,dim=1)
-
-    #val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
-    predicted_caption = torch.full((1,max_generate_length),50256).to(device)
-
-    for g in range(max_generate_length):
-        phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
-        predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
-        predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
-        predicted_caption[:,g] = predicted_word_token.view(1,-1)
-        next_token_embeds = phi_model.model.embed_tokens(predicted_word_token) # 4,1,2560
-        val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
-
-    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
+    predicted_caption = merged_model.generate(inputs_embeds=val_combined_embeds,
+                                              max_new_tokens=max_generate_length,
+                                              return_dict_in_generate = True)
+
+    # predicted_caption = torch.full((1,max_generate_length),50256).to(device)
+
+    # for g in range(max_generate_length):
+    #     phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
+    #     predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
+    #     predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
+    #     predicted_caption[:,g] = predicted_word_token.view(1,-1)
+    #     next_token_embeds = phi_model.model.embed_tokens(predicted_word_token) # 4,1,2560
+    #     val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
+
+    #predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
+    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption.sequences[:, 1:])[0]
+    predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")
 
     return predicted_captions_decoded
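
The change replaces the hand-rolled greedy decoding loop with a single call to Hugging Face's generate() on the concatenated multimodal embeddings. As a rough, self-contained sketch of what that swap amounts to (not the app's actual code: "gpt2" stands in for the merged Phi-2 model, and prompt / max_new_tokens are hypothetical inputs playing the role of val_combined_embeds and max_generate_length):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")            # stand-in for the app's tokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()  # stand-in for merged_model

prompt = "Question: what is in the image? Answer:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)      # (1, seq_len, hidden), like val_combined_embeds
max_new_tokens = 20

# Before: manual greedy loop, re-running the whole embedding prefix on every step.
with torch.no_grad():
    embeds, generated = inputs_embeds, []
    for _ in range(max_new_tokens):
        logits = model(inputs_embeds=embeds).logits               # (1, seq_len, vocab)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # (1, 1) greedy pick
        generated.append(next_id)
        embeds = torch.cat([embeds, model.get_input_embeddings()(next_id)], dim=1)
manual_text = tokenizer.batch_decode(torch.cat(generated, dim=1), skip_special_tokens=True)[0]

# After: one generate() call; greedy decoding when do_sample=False.
with torch.no_grad():
    out = model.generate(inputs_embeds=inputs_embeds,
                         max_new_tokens=max_new_tokens,
                         do_sample=False,
                         return_dict_in_generate=True,
                         pad_token_id=tokenizer.eos_token_id)
generate_text = tokenizer.batch_decode(out.sequences, skip_special_tokens=True)[0]

print(manual_text)
print(generate_text)

Besides being shorter, the generate() path reuses the key/value cache between steps, whereas the removed loop re-encoded the full prefix on every iteration. In recent transformers releases, generating from inputs_embeds alone returns only the newly generated tokens in out.sequences (worth confirming on the pinned version), and skip_special_tokens=True here plays the role of the commit's manual "<|endoftext|>" stripping.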