Spaces:
Runtime error
Runtime error
Commit
•
8625fcd
1
Parent(s):
fecb6e0
Update app.py
Browse files
app.py
CHANGED
@@ -35,17 +35,17 @@ def model_generate_ans(img,val_q):
|
|
35 |
val_image_embeds = projection(clip_val_outputs).to(torch.float16)
|
36 |
|
37 |
img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
|
38 |
-
img_token_embeds =
|
39 |
|
40 |
val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
|
41 |
-
val_q_embeds =
|
42 |
|
43 |
val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
|
44 |
|
45 |
predicted_caption = torch.full((1,max_generate_length),50256)
|
46 |
|
47 |
for g in range(max_generate_length):
|
48 |
-
phi_output_logits =
|
49 |
predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
|
50 |
predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
|
51 |
predicted_caption[:,g] = predicted_word_token.view(1,-1).to(device)
|
|
|
35 |
val_image_embeds = projection(clip_val_outputs).to(torch.float16)
|
36 |
|
37 |
img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
|
38 |
+
img_token_embeds = merged_model.model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
|
39 |
|
40 |
val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
|
41 |
+
val_q_embeds = merged_model.model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
|
42 |
|
43 |
val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
|
44 |
|
45 |
predicted_caption = torch.full((1,max_generate_length),50256)
|
46 |
|
47 |
for g in range(max_generate_length):
|
48 |
+
phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
|
49 |
predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
|
50 |
predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
|
51 |
predicted_caption[:,g] = predicted_word_token.view(1,-1).to(device)
|