Spaces:
Sleeping
Sleeping
sanjanatule
committed on
Commit
•
4e448ea
1
Parent(s):
79438f3
Update app.py
Browse files
app.py
CHANGED
@@ -75,14 +75,19 @@ def model_generate_ans(img=None,img_audio=None,val_q=None):
|
|
75 |
val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
|
76 |
val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
|
77 |
|
78 |
-
val_combined_embeds =
|
79 |
-
|
80 |
if image:
|
81 |
-
val_combined_embeds = torch.cat([val_combined_embeds, val_image_embeds, img_token_embeds], dim=1)
|
|
|
|
|
82 |
if img_audio:
|
83 |
-
val_combined_embeds = torch.cat([val_combined_embeds, audio_embeds], dim=1)
|
|
|
84 |
if val_q:
|
85 |
-
val_combined_embeds = torch.cat([val_combined_embeds, val_q_embeds], dim=1)
|
|
|
|
|
|
|
86 |
|
87 |
#val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
|
88 |
predicted_caption = torch.full((1,max_generate_length),50256).to(device)
|
|
|
75 |
val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
|
76 |
val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
|
77 |
|
78 |
+
val_combined_embeds = []
|
|
|
79 |
if image:
|
80 |
+
#val_combined_embeds = torch.cat([val_combined_embeds, val_image_embeds, img_token_embeds], dim=1)
|
81 |
+
val_combined_embeds.append(val_image_embeds)
|
82 |
+
val_combined_embeds.append(img_token_embeds)
|
83 |
if img_audio:
|
84 |
+
#val_combined_embeds = torch.cat([val_combined_embeds, audio_embeds], dim=1)
|
85 |
+
val_combined_embeds.append(audio_embeds)
|
86 |
if val_q:
|
87 |
+
#val_combined_embeds = torch.cat([val_combined_embeds, val_q_embeds], dim=1)
|
88 |
+
val_combined_embeds.append(val_q_embeds)
|
89 |
+
|
90 |
+
val_combined_embeds = torch.cat(val_combined_embeds,dim=1)
|
91 |
|
92 |
#val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
|
93 |
predicted_caption = torch.full((1,max_generate_length),50256).to(device)
|