sanjanatule committed on
Commit
4e448ea
1 Parent(s): 79438f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -75,14 +75,19 @@ def model_generate_ans(img=None,img_audio=None,val_q=None):
75
  val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
76
  val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
77
 
78
- val_combined_embeds = torch.empty(5,3,0)
79
-
80
  if image:
81
- val_combined_embeds = torch.cat([val_combined_embeds, val_image_embeds, img_token_embeds], dim=1)
 
 
82
  if img_audio:
83
- val_combined_embeds = torch.cat([val_combined_embeds, audio_embeds], dim=1)
 
84
  if val_q:
85
- val_combined_embeds = torch.cat([val_combined_embeds, val_q_embeds], dim=1)
 
 
 
86
 
87
  #val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
88
  predicted_caption = torch.full((1,max_generate_length),50256).to(device)
 
75
  val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
76
  val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
77
 
78
+ val_combined_embeds = []
 
79
  if image:
80
+ #val_combined_embeds = torch.cat([val_combined_embeds, val_image_embeds, img_token_embeds], dim=1)
81
+ val_combined_embeds.append(val_image_embeds)
82
+ val_combined_embeds.append(img_token_embeds)
83
  if img_audio:
84
+ #val_combined_embeds = torch.cat([val_combined_embeds, audio_embeds], dim=1)
85
+ val_combined_embeds.append(audio_embeds)
86
  if val_q:
87
+ #val_combined_embeds = torch.cat([val_combined_embeds, val_q_embeds], dim=1)
88
+ val_combined_embeds.append(val_q_embeds)
89
+
90
+ val_combined_embeds = torch.cat(val_combined_embeds,dim=1)
91
 
92
  #val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
93
  predicted_caption = torch.full((1,max_generate_length),50256).to(device)