Spaces:
Sleeping
Sleeping
Commit • 429d535 — 1 Parent(s): ea37b8e
Update app.py
Browse files
app.py
CHANGED
@@ -82,32 +82,25 @@ def model_generate_ans(img=None,img_audio=None,val_q=None):
|
|
82 |
val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
|
83 |
val_combined_embeds.append(val_q_embeds)
|
84 |
|
85 |
-
# val_combined_embeds = []
|
86 |
-
# if img is not None:
|
87 |
-
# #val_combined_embeds = torch.cat([val_combined_embeds, val_image_embeds, img_token_embeds], dim=1)
|
88 |
-
# val_combined_embeds.append(val_image_embeds)
|
89 |
-
# val_combined_embeds.append(img_token_embeds)
|
90 |
-
# if img_audio is not None:
|
91 |
-
# #val_combined_embeds = torch.cat([val_combined_embeds, audio_embeds], dim=1)
|
92 |
-
# val_combined_embeds.append(audio_embeds)
|
93 |
-
# if len(val_q) != 0:
|
94 |
-
# #val_combined_embeds = torch.cat([val_combined_embeds, val_q_embeds], dim=1)
|
95 |
-
# val_combined_embeds.append(val_q_embeds)
|
96 |
|
97 |
val_combined_embeds = torch.cat(val_combined_embeds,dim=1)
|
98 |
-
|
99 |
-
|
100 |
-
|
|
|
|
|
101 |
|
102 |
-
for g in range(max_generate_length):
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
|
110 |
-
predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
|
|
|
|
|
111 |
|
112 |
return predicted_captions_decoded
|
113 |
|
|
|
82 |
val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
|
83 |
val_combined_embeds.append(val_q_embeds)
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
val_combined_embeds = torch.cat(val_combined_embeds,dim=1)
|
87 |
+
predicted_caption = merged_model.generate(inputs_embeds=val_combined_embeds,
|
88 |
+
max_new_tokens=max_generate_length,
|
89 |
+
return_dict_in_generate = True)
|
90 |
+
|
91 |
+
# predicted_caption = torch.full((1,max_generate_length),50256).to(device)
|
92 |
|
93 |
+
# for g in range(max_generate_length):
|
94 |
+
# phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
|
95 |
+
# predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
|
96 |
+
# predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
|
97 |
+
# predicted_caption[:,g] = predicted_word_token.view(1,-1)
|
98 |
+
# next_token_embeds = phi_model.model.embed_tokens(predicted_word_token) # 4,1,2560
|
99 |
+
# val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
|
100 |
|
101 |
+
#predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
|
102 |
+
predicted_captions_decoded = tokenizer.batch_decode(predicted_caption.sequences[:, 1:])[0]
|
103 |
+
predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")
|
104 |
|
105 |
return predicted_captions_decoded
|
106 |
|