sanjanatule committed
Commit ea37b8e
1 Parent(s): 2b5d0a7

Update app.py

Files changed (1)
  1. app.py +20 -13
app.py CHANGED
@@ -48,6 +48,8 @@ resblock.load_state_dict(torch.load('./model_chkpt/step2_resblock.pth',map_locat
 def model_generate_ans(img=None,img_audio=None,val_q=None):
 
     max_generate_length = 100
+    val_combined_embeds = []
+
     with torch.no_grad():
 
         # image
@@ -60,33 +62,38 @@ def model_generate_ans(img=None,img_audio=None,val_q=None):
         img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
         img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
 
+        val_combined_embeds.append(val_image_embeds)
+        val_combined_embeds.append(img_token_embeds)
+
         # audio
         if img_audio is not None:
-            audio_result = audio_model.transcribe(audio)
+            audio_result = audio_model.transcribe(img_audio)
             audio_text = ''
             for seg in audio_result['segments']:
                 audio_text += seg['text']
             audio_text = audio_text.strip()
             audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
             audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
+            val_combined_embeds.append(audio_embeds)
 
         # text question
-        if val_q is not None:
+        if len(val_q) != 0:
             val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
             val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
-
-        val_combined_embeds = []
-        if img is not None:
-            #val_combined_embeds = torch.cat([val_combined_embeds, val_image_embeds, img_token_embeds], dim=1)
-            val_combined_embeds.append(val_image_embeds)
-            val_combined_embeds.append(img_token_embeds)
-        if img_audio is not None:
-            #val_combined_embeds = torch.cat([val_combined_embeds, audio_embeds], dim=1)
-            val_combined_embeds.append(audio_embeds)
-        if val_q is not None:
-            #val_combined_embeds = torch.cat([val_combined_embeds, val_q_embeds], dim=1)
             val_combined_embeds.append(val_q_embeds)
 
+        # val_combined_embeds = []
+        # if img is not None:
+        #     #val_combined_embeds = torch.cat([val_combined_embeds, val_image_embeds, img_token_embeds], dim=1)
+        #     val_combined_embeds.append(val_image_embeds)
+        #     val_combined_embeds.append(img_token_embeds)
+        # if img_audio is not None:
+        #     #val_combined_embeds = torch.cat([val_combined_embeds, audio_embeds], dim=1)
+        #     val_combined_embeds.append(audio_embeds)
+        # if len(val_q) != 0:
+        #     #val_combined_embeds = torch.cat([val_combined_embeds, val_q_embeds], dim=1)
+        #     val_combined_embeds.append(val_q_embeds)
+
         val_combined_embeds = torch.cat(val_combined_embeds,dim=1)
 
         #val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
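
The pattern this commit settles on is: initialise an empty list up front, append whichever modality embeddings are actually present (image, optional audio transcript, optional text question), and call torch.cat once along the sequence dimension. A minimal standalone sketch of that pattern follows; the tensor shapes, the sequence lengths, and the placeholder variable names are illustrative assumptions, not values taken from app.py (only the 2560 hidden size appears in the file's own comment).

import torch

# Illustrative stand-ins for the per-modality embeddings (assumed shapes;
# in app.py these come from the image projector, the Whisper transcript
# tokens, and the question tokens respectively).
embed_dim = 2560
image_embeds = torch.randn(1, 49, embed_dim)       # like val_image_embeds
image_token_embed = torch.randn(1, 1, embed_dim)   # like img_token_embeds
question_embeds = torch.randn(1, 12, embed_dim)    # like val_q_embeds

combined = []                         # created before any branch, as in the commit
combined.append(image_embeds)
combined.append(image_token_embed)
# an audio embedding would be appended here only when audio input is supplied
combined.append(question_embeds)

combined = torch.cat(combined, dim=1)  # concatenate along the sequence dimension
print(combined.shape)                  # torch.Size([1, 62, 2560])

Because the list is created before the optional branches, torch.cat always receives at least the image embeddings, so omitting the audio clip or the typed question does not change how the remaining prompt embeddings are assembled.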