Commit 31a9142
Parent(s): e9d7857

Update app.py
app.py CHANGED
@@ -45,32 +45,34 @@ resblock.load_state_dict(torch.load('./model_chkpt/step2_resblock.pth',map_locat
 def model_generate_ans(img,val_q):
 
     max_generate_length = 30
 
-    # image
-    image_processed = processor(images=img, return_tensors="pt").to(device)
-    clip_val_outputs = clip_model(**image_processed).last_hidden_state[:,1:,:]
-    val_image_embeds = projection(clip_val_outputs)
-    val_image_embeds = resblock(val_image_embeds).to(torch.float16)
-
-    img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
-    img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
-
-    val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
-    val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
-
-    val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
-
-    predicted_caption = torch.full((1,max_generate_length),50256)
-
-    for g in range(max_generate_length):
-        phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
-        predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
-        predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
-        predicted_caption[:,g] = predicted_word_token.view(1,-1)
-        next_token_embeds = phi_model.model.embed_tokens(predicted_word_token) # 4,1,2560
-        val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
-
-    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)
+    with torch.no_grad():
+
+        # image
+        image_processed = processor(images=img, return_tensors="pt").to(device)
+        clip_val_outputs = clip_model(**image_processed).last_hidden_state[:,1:,:]
+        val_image_embeds = projection(clip_val_outputs)
+        val_image_embeds = resblock(val_image_embeds).to(torch.float16)
+
+        img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
+        img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
+
+        val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
+        val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
+
+        val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
+
+        predicted_caption = torch.full((1,max_generate_length),50256)
+
+        for g in range(max_generate_length):
+            phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
+            predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
+            predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
+            predicted_caption[:,g] = predicted_word_token.view(1,-1)
+            next_token_embeds = phi_model.model.embed_tokens(predicted_word_token) # 4,1,2560
+            val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
+
+        predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)
 
     return predicted_captions_decoded
 
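The net effect of the commit is that the whole answer-generation path (CLIP image encoding, projection and residual block, and the greedy token-by-token decode with the merged Phi model) now runs inside with torch.no_grad():, so no autograd graph is built while the Space answers a question. For context, below is a minimal sketch of how the updated model_generate_ans might be exposed as a demo. The Gradio wiring and widget labels are assumptions for illustration only (the commit touches nothing but model_generate_ans), and the snippet assumes it sits in the same app.py where the models, processor, and tokenizer are already loaded.

import gradio as gr

# Hypothetical wiring, not part of this commit: expose model_generate_ans
# through a simple Gradio interface, the usual front end for a Space.
def answer(img, val_q):
    # model_generate_ans returns the list produced by tokenizer.batch_decode;
    # show the first decoded string.
    return model_generate_ans(img, val_q)[0]

demo = gr.Interface(
    fn=answer,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="Question")],
    outputs=gr.Textbox(label="Answer"),
)

demo.launch()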